{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d298e185",
   "metadata": {},
   "source": [
    "The goal of this, and the companion (part 2) notebooks is to illustrate how one could use this library in the context of recommendation systems. In particular, this notebook and the scripts at the `wide_deep_for_recsys` dir are a response to this [issue](https://github.com/jrzaurin/pytorch-widedeep/issues/133). Therefore, we will use the [Kaggle notebook](https://www.kaggle.com/code/matanivanov/wide-deep-learning-for-recsys-with-pytorch) referred in that issue here.\n",
    "\n",
    "In order to keep the length of the notebook tractable, we will split this exercise in 2. In this first notebook we will prepare the [data](https://www.kaggle.com/datasets/prajitdatta/movielens-100k-dataset) in almost the exact same way as it is done in the Kaggle notebook and also show how one could use `pytorch-widedeep` to build a model almost identical to the one in that notebook. \n",
    "\n",
    "In a second notebook, we will show how one could use this library to implement other models, still following the same problem formulation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ebd9980d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import warnings\n",
    "\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from pytorch_widedeep.datasets import load_movielens100k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7cd76bce",
   "metadata": {},
   "outputs": [],
   "source": [
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0aed611e",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_path = Path(\"prepared_data\")\n",
    "if not save_path.exists():\n",
    "    save_path.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5de7a941",
   "metadata": {},
   "outputs": [],
   "source": [
    "data, users, items = load_movielens100k(as_frame=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7a288aee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternatively, as specified in the docs: 'The last 19 fields are the genres' so:\n",
    "# list_of_genres = items.columns.tolist()[-19:]\n",
    "list_of_genres = [\n",
    "    \"unknown\",\n",
    "    \"Action\",\n",
    "    \"Adventure\",\n",
    "    \"Animation\",\n",
    "    \"Children's\",\n",
    "    \"Comedy\",\n",
    "    \"Crime\",\n",
    "    \"Documentary\",\n",
    "    \"Drama\",\n",
    "    \"Fantasy\",\n",
    "    \"Film-Noir\",\n",
    "    \"Horror\",\n",
    "    \"Musical\",\n",
    "    \"Mystery\",\n",
    "    \"Romance\",\n",
    "    \"Sci-Fi\",\n",
    "    \"Thriller\",\n",
    "    \"War\",\n",
    "    \"Western\",\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "929a9712",
   "metadata": {},
   "source": [
    "Let's first start by loading the interactions, user and item data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f4c09273",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user_id</th>\n",
       "      <th>movie_id</th>\n",
       "      <th>rating</th>\n",
       "      <th>timestamp</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>196</td>\n",
       "      <td>242</td>\n",
       "      <td>3</td>\n",
       "      <td>881250949</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>186</td>\n",
       "      <td>302</td>\n",
       "      <td>3</td>\n",
       "      <td>891717742</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>22</td>\n",
       "      <td>377</td>\n",
       "      <td>1</td>\n",
       "      <td>878887116</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>244</td>\n",
       "      <td>51</td>\n",
       "      <td>2</td>\n",
       "      <td>880606923</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>166</td>\n",
       "      <td>346</td>\n",
       "      <td>1</td>\n",
       "      <td>886397596</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   user_id  movie_id  rating  timestamp\n",
       "0      196       242       3  881250949\n",
       "1      186       302       3  891717742\n",
       "2       22       377       1  878887116\n",
       "3      244        51       2  880606923\n",
       "4      166       346       1  886397596"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "18c3faa0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user_id</th>\n",
       "      <th>age</th>\n",
       "      <th>gender</th>\n",
       "      <th>occupation</th>\n",
       "      <th>zip_code</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>24</td>\n",
       "      <td>M</td>\n",
       "      <td>technician</td>\n",
       "      <td>85711</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>53</td>\n",
       "      <td>F</td>\n",
       "      <td>other</td>\n",
       "      <td>94043</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>23</td>\n",
       "      <td>M</td>\n",
       "      <td>writer</td>\n",
       "      <td>32067</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>24</td>\n",
       "      <td>M</td>\n",
       "      <td>technician</td>\n",
       "      <td>43537</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>33</td>\n",
       "      <td>F</td>\n",
       "      <td>other</td>\n",
       "      <td>15213</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   user_id  age gender  occupation zip_code\n",
       "0        1   24      M  technician    85711\n",
       "1        2   53      F       other    94043\n",
       "2        3   23      M      writer    32067\n",
       "3        4   24      M  technician    43537\n",
       "4        5   33      F       other    15213"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "users.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1dbad7b1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>movie_id</th>\n",
       "      <th>movie_title</th>\n",
       "      <th>release_date</th>\n",
       "      <th>video_release_date</th>\n",
       "      <th>IMDb_URL</th>\n",
       "      <th>unknown</th>\n",
       "      <th>Action</th>\n",
       "      <th>Adventure</th>\n",
       "      <th>Animation</th>\n",
       "      <th>Children's</th>\n",
       "      <th>...</th>\n",
       "      <th>Fantasy</th>\n",
       "      <th>Film-Noir</th>\n",
       "      <th>Horror</th>\n",
       "      <th>Musical</th>\n",
       "      <th>Mystery</th>\n",
       "      <th>Romance</th>\n",
       "      <th>Sci-Fi</th>\n",
       "      <th>Thriller</th>\n",
       "      <th>War</th>\n",
       "      <th>Western</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>Toy Story (1995)</td>\n",
       "      <td>01-Jan-1995</td>\n",
       "      <td>NaN</td>\n",
       "      <td>http://us.imdb.com/M/title-exact?Toy%20Story%2...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>GoldenEye (1995)</td>\n",
       "      <td>01-Jan-1995</td>\n",
       "      <td>NaN</td>\n",
       "      <td>http://us.imdb.com/M/title-exact?GoldenEye%20(...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>Four Rooms (1995)</td>\n",
       "      <td>01-Jan-1995</td>\n",
       "      <td>NaN</td>\n",
       "      <td>http://us.imdb.com/M/title-exact?Four%20Rooms%...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>Get Shorty (1995)</td>\n",
       "      <td>01-Jan-1995</td>\n",
       "      <td>NaN</td>\n",
       "      <td>http://us.imdb.com/M/title-exact?Get%20Shorty%...</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>Copycat (1995)</td>\n",
       "      <td>01-Jan-1995</td>\n",
       "      <td>NaN</td>\n",
       "      <td>http://us.imdb.com/M/title-exact?Copycat%20(1995)</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 24 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   movie_id        movie_title release_date  video_release_date  \\\n",
       "0         1   Toy Story (1995)  01-Jan-1995                 NaN   \n",
       "1         2   GoldenEye (1995)  01-Jan-1995                 NaN   \n",
       "2         3  Four Rooms (1995)  01-Jan-1995                 NaN   \n",
       "3         4  Get Shorty (1995)  01-Jan-1995                 NaN   \n",
       "4         5     Copycat (1995)  01-Jan-1995                 NaN   \n",
       "\n",
       "                                            IMDb_URL  unknown  Action  \\\n",
       "0  http://us.imdb.com/M/title-exact?Toy%20Story%2...        0       0   \n",
       "1  http://us.imdb.com/M/title-exact?GoldenEye%20(...        0       1   \n",
       "2  http://us.imdb.com/M/title-exact?Four%20Rooms%...        0       0   \n",
       "3  http://us.imdb.com/M/title-exact?Get%20Shorty%...        0       1   \n",
       "4  http://us.imdb.com/M/title-exact?Copycat%20(1995)        0       0   \n",
       "\n",
       "   Adventure  Animation  Children's  ...  Fantasy  Film-Noir  Horror  Musical  \\\n",
       "0          0          1           1  ...        0          0       0        0   \n",
       "1          1          0           0  ...        0          0       0        0   \n",
       "2          0          0           0  ...        0          0       0        0   \n",
       "3          0          0           0  ...        0          0       0        0   \n",
       "4          0          0           0  ...        0          0       0        0   \n",
       "\n",
       "   Mystery  Romance  Sci-Fi  Thriller  War  Western  \n",
       "0        0        0       0         0    0        0  \n",
       "1        0        0       0         1    0        0  \n",
       "2        0        0       0         1    0        0  \n",
       "3        0        0       0         0    0        0  \n",
       "4        0        0       0         1    0        0  \n",
       "\n",
       "[5 rows x 24 columns]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "items.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3cb7bbc5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user_id</th>\n",
       "      <th>movie_id</th>\n",
       "      <th>rating</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>num_watched</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>168</td>\n",
       "      <td>5</td>\n",
       "      <td>874965478</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>172</td>\n",
       "      <td>5</td>\n",
       "      <td>874965478</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>165</td>\n",
       "      <td>5</td>\n",
       "      <td>874965518</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>156</td>\n",
       "      <td>4</td>\n",
       "      <td>874965556</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>196</td>\n",
       "      <td>5</td>\n",
       "      <td>874965677</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   user_id  movie_id  rating  timestamp  num_watched\n",
       "0        1       168       5  874965478            1\n",
       "1        1       172       5  874965478            2\n",
       "2        1       165       5  874965518            3\n",
       "3        1       156       4  874965556            4\n",
       "4        1       196       5  874965677            5"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# adding a column with the number of movies watched per user\n",
    "dataset = data.sort_values([\"user_id\", \"timestamp\"]).reset_index(drop=True)\n",
    "dataset[\"one\"] = 1\n",
    "dataset[\"num_watched\"] = dataset.groupby(\"user_id\")[\"one\"].cumsum()\n",
    "dataset.drop(\"one\", axis=1, inplace=True)\n",
    "dataset.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "cf7c5da2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user_id</th>\n",
       "      <th>movie_id</th>\n",
       "      <th>rating</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>num_watched</th>\n",
       "      <th>mean_rate</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>168</td>\n",
       "      <td>5</td>\n",
       "      <td>874965478</td>\n",
       "      <td>1</td>\n",
       "      <td>5.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>172</td>\n",
       "      <td>5</td>\n",
       "      <td>874965478</td>\n",
       "      <td>2</td>\n",
       "      <td>5.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>165</td>\n",
       "      <td>5</td>\n",
       "      <td>874965518</td>\n",
       "      <td>3</td>\n",
       "      <td>5.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>156</td>\n",
       "      <td>4</td>\n",
       "      <td>874965556</td>\n",
       "      <td>4</td>\n",
       "      <td>4.75</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>196</td>\n",
       "      <td>5</td>\n",
       "      <td>874965677</td>\n",
       "      <td>5</td>\n",
       "      <td>4.80</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   user_id  movie_id  rating  timestamp  num_watched  mean_rate\n",
       "0        1       168       5  874965478            1       5.00\n",
       "1        1       172       5  874965478            2       5.00\n",
       "2        1       165       5  874965518            3       5.00\n",
       "3        1       156       4  874965556            4       4.75\n",
       "4        1       196       5  874965677            5       4.80"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# adding a column with the mean rating at a point in time per user\n",
    "dataset[\"mean_rate\"] = (\n",
    "    dataset.groupby(\"user_id\")[\"rating\"].cumsum() / dataset[\"num_watched\"]\n",
    ")\n",
    "dataset.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29d1c399",
   "metadata": {},
   "source": [
    "### Problem formulation\n",
    "\n",
    "In this particular exercise the problem is formulated as predicting the next movie that will be watched (in consequence the last interactions will be discarded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "0e9d1315",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset[\"target\"] = dataset.groupby(\"user_id\")[\"movie_id\"].shift(-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b38bba10",
   "metadata": {},
   "source": [
    "Following the same processing used by the author in the before-mentioned Kaggle notebook, we build sequences of previous movies watched"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "f001f2b4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user_id</th>\n",
       "      <th>movie_id</th>\n",
       "      <th>rating</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>num_watched</th>\n",
       "      <th>mean_rate</th>\n",
       "      <th>target</th>\n",
       "      <th>prev_movies</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>168</td>\n",
       "      <td>5</td>\n",
       "      <td>874965478</td>\n",
       "      <td>1</td>\n",
       "      <td>5.00</td>\n",
       "      <td>172.0</td>\n",
       "      <td>[168]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>172</td>\n",
       "      <td>5</td>\n",
       "      <td>874965478</td>\n",
       "      <td>2</td>\n",
       "      <td>5.00</td>\n",
       "      <td>165.0</td>\n",
       "      <td>[168, 172]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>165</td>\n",
       "      <td>5</td>\n",
       "      <td>874965518</td>\n",
       "      <td>3</td>\n",
       "      <td>5.00</td>\n",
       "      <td>156.0</td>\n",
       "      <td>[168, 172, 165]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>156</td>\n",
       "      <td>4</td>\n",
       "      <td>874965556</td>\n",
       "      <td>4</td>\n",
       "      <td>4.75</td>\n",
       "      <td>196.0</td>\n",
       "      <td>[168, 172, 165, 156]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>196</td>\n",
       "      <td>5</td>\n",
       "      <td>874965677</td>\n",
       "      <td>5</td>\n",
       "      <td>4.80</td>\n",
       "      <td>166.0</td>\n",
       "      <td>[168, 172, 165, 156, 196]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   user_id  movie_id  rating  timestamp  num_watched  mean_rate  target  \\\n",
       "0        1       168       5  874965478            1       5.00   172.0   \n",
       "1        1       172       5  874965478            2       5.00   165.0   \n",
       "2        1       165       5  874965518            3       5.00   156.0   \n",
       "3        1       156       4  874965556            4       4.75   196.0   \n",
       "4        1       196       5  874965677            5       4.80   166.0   \n",
       "\n",
       "                 prev_movies  \n",
       "0                      [168]  \n",
       "1                 [168, 172]  \n",
       "2            [168, 172, 165]  \n",
       "3       [168, 172, 165, 156]  \n",
       "4  [168, 172, 165, 156, 196]  "
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Here the author builds the sequences\n",
    "dataset[\"prev_movies\"] = dataset[\"movie_id\"].apply(lambda x: str(x))\n",
    "dataset[\"prev_movies\"] = (\n",
    "    dataset.groupby(\"user_id\")[\"prev_movies\"]\n",
    "    .apply(lambda x: (x + \" \").cumsum().str.strip())\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "dataset[\"prev_movies\"] = dataset[\"prev_movies\"].apply(lambda x: x.split())\n",
    "dataset.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a024b9c4",
   "metadata": {},
   "source": [
    "And now we add a `genre_rate` as the mean of all movies rated for a given genre per user\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "5782f0c9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user_id</th>\n",
       "      <th>movie_id</th>\n",
       "      <th>rating</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>num_watched</th>\n",
       "      <th>mean_rate</th>\n",
       "      <th>target</th>\n",
       "      <th>prev_movies</th>\n",
       "      <th>unknown</th>\n",
       "      <th>Action</th>\n",
       "      <th>...</th>\n",
       "      <th>Fantasy_rate</th>\n",
       "      <th>Film-Noir_rate</th>\n",
       "      <th>Horror_rate</th>\n",
       "      <th>Musical_rate</th>\n",
       "      <th>Mystery_rate</th>\n",
       "      <th>Romance_rate</th>\n",
       "      <th>Sci-Fi_rate</th>\n",
       "      <th>Thriller_rate</th>\n",
       "      <th>War_rate</th>\n",
       "      <th>Western_rate</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>168</td>\n",
       "      <td>5</td>\n",
       "      <td>874965478</td>\n",
       "      <td>1</td>\n",
       "      <td>5.00</td>\n",
       "      <td>172.0</td>\n",
       "      <td>[168]</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>172</td>\n",
       "      <td>5</td>\n",
       "      <td>874965478</td>\n",
       "      <td>2</td>\n",
       "      <td>5.00</td>\n",
       "      <td>165.0</td>\n",
       "      <td>[168, 172]</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>5.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>165</td>\n",
       "      <td>5</td>\n",
       "      <td>874965518</td>\n",
       "      <td>3</td>\n",
       "      <td>5.00</td>\n",
       "      <td>156.0</td>\n",
       "      <td>[168, 172, 165]</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>5.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>156</td>\n",
       "      <td>4</td>\n",
       "      <td>874965556</td>\n",
       "      <td>4</td>\n",
       "      <td>4.75</td>\n",
       "      <td>196.0</td>\n",
       "      <td>[168, 172, 165, 156]</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.250000</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>196</td>\n",
       "      <td>5</td>\n",
       "      <td>874965677</td>\n",
       "      <td>5</td>\n",
       "      <td>4.80</td>\n",
       "      <td>166.0</td>\n",
       "      <td>[168, 172, 165, 156, 196]</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.200000</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 46 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   user_id  movie_id  rating  timestamp  num_watched  mean_rate  target  \\\n",
       "0        1       168       5  874965478            1       5.00   172.0   \n",
       "1        1       172       5  874965478            2       5.00   165.0   \n",
       "2        1       165       5  874965518            3       5.00   156.0   \n",
       "3        1       156       4  874965556            4       4.75   196.0   \n",
       "4        1       196       5  874965677            5       4.80   166.0   \n",
       "\n",
       "                 prev_movies  unknown    Action  ...  Fantasy_rate  \\\n",
       "0                      [168]      0.0  0.000000  ...           NaN   \n",
       "1                 [168, 172]      0.0  0.500000  ...           NaN   \n",
       "2            [168, 172, 165]      0.0  0.333333  ...           NaN   \n",
       "3       [168, 172, 165, 156]      0.0  0.250000  ...           NaN   \n",
       "4  [168, 172, 165, 156, 196]      0.0  0.200000  ...           NaN   \n",
       "\n",
       "   Film-Noir_rate  Horror_rate  Musical_rate  Mystery_rate  Romance_rate  \\\n",
       "0             NaN          NaN           NaN           NaN           NaN   \n",
       "1             NaN          NaN           NaN           NaN           5.0   \n",
       "2             NaN          NaN           NaN           NaN           5.0   \n",
       "3             NaN          NaN           NaN           NaN           5.0   \n",
       "4             NaN          NaN           NaN           NaN           5.0   \n",
       "\n",
       "   Sci-Fi_rate  Thriller_rate  War_rate  Western_rate  \n",
       "0          NaN            NaN       NaN           NaN  \n",
       "1          5.0            NaN       5.0           NaN  \n",
       "2          5.0            NaN       5.0           NaN  \n",
       "3          5.0            4.0       5.0           NaN  \n",
       "4          5.0            4.0       5.0           NaN  \n",
       "\n",
       "[5 rows x 46 columns]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset = dataset.merge(items[[\"movie_id\"] + list_of_genres], on=\"movie_id\", how=\"left\")\n",
    "for genre in list_of_genres:\n",
    "    dataset[f\"{genre}_rate\"] = dataset[genre] * dataset[\"rating\"]\n",
    "    dataset[genre] = dataset.groupby(\"user_id\")[genre].cumsum()\n",
    "    dataset[f\"{genre}_rate\"] = (\n",
    "        dataset.groupby(\"user_id\")[f\"{genre}_rate\"].cumsum() / dataset[genre]\n",
    "    )\n",
    "dataset[list_of_genres] = dataset[list_of_genres].apply(\n",
    "    lambda x: x / dataset[\"num_watched\"]\n",
    ")\n",
    "dataset.head()"
   ]
  },
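  {
   "cell_type": "markdown",
   "id": "genre-feature-formulas",
   "metadata": {},
   "source": [
    "Written out as formulas, the genre features computed above are as follows. The notation is introduced here for illustration only, and it assumes that each user's rows are in chronological order and that `num_watched` is the running count $n$ of movies rated so far. Let $g_{u,i}$ be 1 if the $i$-th movie rated by user $u$ belongs to a given genre and 0 otherwise, and let $r_{u,i}$ be the corresponding rating. Then the genre column and the `{genre}_rate` column hold, respectively:\n",
    "\n",
    "$$\n",
    "\\text{share}_{u}(n) = \\frac{1}{n} \\sum_{i=1}^{n} g_{u,i}, \\qquad \\text{rate}_{u}(n) = \\frac{\\sum_{i=1}^{n} g_{u,i} \\, r_{u,i}}{\\sum_{i=1}^{n} g_{u,i}}\n",
    "$$\n",
    "\n",
    "For example, in the output above user 1 has `Action` equal to 0.5 after two movies (one of the two is an action movie), and `Thriller_rate` is `NaN` until the fourth movie, the first one tagged as a thriller, which was rated 4. The `NaN` values are simply the $0/0$ case, where the user has not yet rated any movie of that genre."
   ]
  },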
  {
   "cell_type": "markdown",
   "id": "7029510d",
   "metadata": {},
   "source": [
    "Adding user features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "df698ec8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user_id</th>\n",
       "      <th>movie_id</th>\n",
       "      <th>rating</th>\n",
       "      <th>timestamp</th>\n",
       "      <th>num_watched</th>\n",
       "      <th>mean_rate</th>\n",
       "      <th>target</th>\n",
       "      <th>prev_movies</th>\n",
       "      <th>unknown</th>\n",
       "      <th>Action</th>\n",
       "      <th>...</th>\n",
       "      <th>Mystery_rate</th>\n",
       "      <th>Romance_rate</th>\n",
       "      <th>Sci-Fi_rate</th>\n",
       "      <th>Thriller_rate</th>\n",
       "      <th>War_rate</th>\n",
       "      <th>Western_rate</th>\n",
       "      <th>age</th>\n",
       "      <th>gender</th>\n",
       "      <th>occupation</th>\n",
       "      <th>zip_code</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>168</td>\n",
       "      <td>5</td>\n",
       "      <td>874965478</td>\n",
       "      <td>1</td>\n",
       "      <td>5.00</td>\n",
       "      <td>172.0</td>\n",
       "      <td>[168]</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>24</td>\n",
       "      <td>M</td>\n",
       "      <td>technician</td>\n",
       "      <td>85711</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>172</td>\n",
       "      <td>5</td>\n",
       "      <td>874965478</td>\n",
       "      <td>2</td>\n",
       "      <td>5.00</td>\n",
       "      <td>165.0</td>\n",
       "      <td>[168, 172]</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>5.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>24</td>\n",
       "      <td>M</td>\n",
       "      <td>technician</td>\n",
       "      <td>85711</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>165</td>\n",
       "      <td>5</td>\n",
       "      <td>874965518</td>\n",
       "      <td>3</td>\n",
       "      <td>5.00</td>\n",
       "      <td>156.0</td>\n",
       "      <td>[168, 172, 165]</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>5.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>24</td>\n",
       "      <td>M</td>\n",
       "      <td>technician</td>\n",
       "      <td>85711</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>156</td>\n",
       "      <td>4</td>\n",
       "      <td>874965556</td>\n",
       "      <td>4</td>\n",
       "      <td>4.75</td>\n",
       "      <td>196.0</td>\n",
       "      <td>[168, 172, 165, 156]</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.250000</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>24</td>\n",
       "      <td>M</td>\n",
       "      <td>technician</td>\n",
       "      <td>85711</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>196</td>\n",
       "      <td>5</td>\n",
       "      <td>874965677</td>\n",
       "      <td>5</td>\n",
       "      <td>4.80</td>\n",
       "      <td>166.0</td>\n",
       "      <td>[168, 172, 165, 156, 196]</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.200000</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>24</td>\n",
       "      <td>M</td>\n",
       "      <td>technician</td>\n",
       "      <td>85711</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 50 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   user_id  movie_id  rating  timestamp  num_watched  mean_rate  target  \\\n",
       "0        1       168       5  874965478            1       5.00   172.0   \n",
       "1        1       172       5  874965478            2       5.00   165.0   \n",
       "2        1       165       5  874965518            3       5.00   156.0   \n",
       "3        1       156       4  874965556            4       4.75   196.0   \n",
       "4        1       196       5  874965677            5       4.80   166.0   \n",
       "\n",
       "                 prev_movies  unknown    Action  ...  Mystery_rate  \\\n",
       "0                      [168]      0.0  0.000000  ...           NaN   \n",
       "1                 [168, 172]      0.0  0.500000  ...           NaN   \n",
       "2            [168, 172, 165]      0.0  0.333333  ...           NaN   \n",
       "3       [168, 172, 165, 156]      0.0  0.250000  ...           NaN   \n",
       "4  [168, 172, 165, 156, 196]      0.0  0.200000  ...           NaN   \n",
       "\n",
       "   Romance_rate  Sci-Fi_rate  Thriller_rate  War_rate  Western_rate  age  \\\n",
       "0           NaN          NaN            NaN       NaN           NaN   24   \n",
       "1           5.0          5.0            NaN       5.0           NaN   24   \n",
       "2           5.0          5.0            NaN       5.0           NaN   24   \n",
       "3           5.0          5.0            4.0       5.0           NaN   24   \n",
       "4           5.0          5.0            4.0       5.0           NaN   24   \n",
       "\n",
       "   gender  occupation  zip_code  \n",
       "0       M  technician     85711  \n",
       "1       M  technician     85711  \n",
       "2       M  technician     85711  \n",
       "3       M  technician     85711  \n",
       "4       M  technician     85711  \n",
       "\n",
       "[5 rows x 50 columns]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset = dataset.merge(users, on=\"user_id\", how=\"left\")\n",
    "dataset.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee62d77e",
   "metadata": {},
   "source": [
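    "Again, we use the same settings as in the Kaggle notebook, although `COLD_START_TRESH` is pretty aggressive. The data is sorted by timestamp and split, without shuffling, into 80% train, 10% validation and 10% test."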
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "8060cf59",
   "metadata": {},
   "outputs": [],
   "source": [
    "COLD_START_TRESH = 5\n",
    "\n",
    "filtred_data = dataset[\n",
    "    (dataset[\"num_watched\"] >= COLD_START_TRESH) & ~(dataset[\"target\"].isna())\n",
    "].sort_values(\"timestamp\")\n",
    "train_data, _test_data = train_test_split(filtred_data, test_size=0.2, shuffle=False)\n",
    "valid_data, test_data = train_test_split(_test_data, test_size=0.5, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "b1beb347",
   "metadata": {},
   "outputs": [],
   "source": [
    "cols_to_drop = [\n",
    "    # \"rating\",\n",
    "    \"timestamp\",\n",
    "    \"num_watched\",\n",
    "]\n",
    "\n",
    "df_train = train_data.drop(cols_to_drop, axis=1)\n",
    "df_valid = valid_data.drop(cols_to_drop, axis=1)\n",
    "df_test = test_data.drop(cols_to_drop, axis=1)\n",
    "\n",
    "df_train.to_pickle(save_path / \"df_train.pkl\")\n",
    "df_valid.to_pickle(save_path / \"df_valid.pkl\")\n",
    "df_test.to_pickle(save_path / \"df_test.pkl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5bf71a82",
   "metadata": {},
   "source": [
    "Let's now build a model that is nearly identical to the one used in the [Kaggle notebook](https://www.kaggle.com/code/matanivanov/wide-deep-learning-for-recsys-with-pytorch)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "6aa2e3f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from scipy.sparse import coo_matrix\n",
    "\n",
    "from pytorch_widedeep import Trainer\n",
    "from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep\n",
    "from pytorch_widedeep.preprocessing import TabPreprocessor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "42b0d88f",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "save_path = Path(\"prepared_data\")\n",
    "\n",
    "PAD_IDX = 0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be204fe8",
   "metadata": {},
   "source": [
    "Let's use some of the functions that the author of the Kaggle notebook uses to prepare the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "206eb90e",
   "metadata": {},
   "outputs": [],
   "source": [
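    "def get_coo_indexes(lil):\n",
    "    # Build (row, col) index lists from a list of lists: row i gets one\n",
    "    # entry per element of lil[i]; scalars are treated as 1-element lists\n",
    "    rows = []\n",
    "    cols = []\n",
    "    for i, el in enumerate(lil):\n",
    "        if not isinstance(el, list):\n",
    "            el = [el]\n",
    "        for j in el:\n",
    "            rows.append(i)\n",
    "            cols.append(j)\n",
    "    return rows, cols\n",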
    "\n",
    "\n",
    "def get_sparse_features(series, shape):\n",
    "    coo_indexes = get_coo_indexes(series.tolist())\n",
    "    sparse_df = coo_matrix(\n",
    "        (np.ones(len(coo_indexes[0])), (coo_indexes[0], coo_indexes[1])), shape=shape\n",
    "    )\n",
    "    return sparse_df\n",
    "\n",
    "\n",
    "def sparse_to_idx(data, pad_idx=-1):\n",
    "    indexes = data.nonzero()\n",
    "    indexes_df = pd.DataFrame()\n",
    "    indexes_df[\"rows\"] = indexes[0]\n",
    "    indexes_df[\"cols\"] = indexes[1]\n",
    "    mdf = indexes_df.groupby(\"rows\").apply(lambda x: x[\"cols\"].tolist())\n",
    "    max_len = mdf.apply(lambda x: len(x)).max()\n",
    "    return mdf.apply(lambda x: pd.Series(x + [pad_idx] * (max_len - len(x)))).values"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ca8dd42",
   "metadata": {},
   "source": [
    "For the time being, we will not use a validation set for hyperparameter optimization; we will simply concatenate the validation and test sets into a single test set. I split the data into train/valid/test in case the reader wants to actually do hyperparameter optimization (and because I know that in the future I will).\n",
    "\n",
    "There is another caveat worth mentioning, related to the indexing of the movies. To build the matrices of movies watched, we use the entire dataset. A more realistic (and correct) approach would be to use ONLY the movies that appear in the training set and to treat as `unknown` or `unseen` those movies in the test set that were not seen during training. Nonetheless, this does not affect the purpose of this notebook, which is to illustrate how one could use `pytorch-widedeep` to build a recommendation algorithm. However, if one wanted to compare the performance of different algorithms in a \"proper\" way, these \"details\" would need to be accounted for."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "39f778bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_test = pd.concat([df_valid, df_test], ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "ab7483c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "id_cols = [\"user_id\", \"movie_id\"]\n",
    "max_movie_index = max(df_train.movie_id.max(), df_test.movie_id.max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "3d17bd3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train = df_train.drop(id_cols + [\"rating\", \"prev_movies\", \"target\"], axis=1)\n",
    "y_train = np.array(df_train.target.values, dtype=\"int64\")\n",
    "train_movies_watched = get_sparse_features(\n",
    "    df_train[\"prev_movies\"], (len(df_train), max_movie_index + 1)\n",
    ")\n",
    "\n",
    "X_test = df_test.drop(id_cols + [\"rating\", \"prev_movies\", \"target\"], axis=1)\n",
    "y_test = np.array(df_test.target.values, dtype=\"int64\")\n",
    "test_movies_watched = get_sparse_features(\n",
    "    df_test[\"prev_movies\"], (len(df_test), max_movie_index + 1)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "511e95ed",
   "metadata": {},
   "source": [
    "Let's have a look at the information in each dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "dd9e5ef3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>mean_rate</th>\n",
       "      <th>unknown</th>\n",
       "      <th>Action</th>\n",
       "      <th>Adventure</th>\n",
       "      <th>Animation</th>\n",
       "      <th>Children's</th>\n",
       "      <th>Comedy</th>\n",
       "      <th>Crime</th>\n",
       "      <th>Documentary</th>\n",
       "      <th>Drama</th>\n",
       "      <th>...</th>\n",
       "      <th>Mystery_rate</th>\n",
       "      <th>Romance_rate</th>\n",
       "      <th>Sci-Fi_rate</th>\n",
       "      <th>Thriller_rate</th>\n",
       "      <th>War_rate</th>\n",
       "      <th>Western_rate</th>\n",
       "      <th>age</th>\n",
       "      <th>gender</th>\n",
       "      <th>occupation</th>\n",
       "      <th>zip_code</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>25423</th>\n",
       "      <td>4.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.400000</td>\n",
       "      <td>0.200000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.400000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.200000</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.000000</td>\n",
       "      <td>4.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>21</td>\n",
       "      <td>M</td>\n",
       "      <td>student</td>\n",
       "      <td>48823</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25425</th>\n",
       "      <td>4.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.285714</td>\n",
       "      <td>0.142857</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.428571</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.285714</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.000000</td>\n",
       "      <td>4.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>21</td>\n",
       "      <td>M</td>\n",
       "      <td>student</td>\n",
       "      <td>48823</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25424</th>\n",
       "      <td>4.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.166667</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.000000</td>\n",
       "      <td>4.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>21</td>\n",
       "      <td>M</td>\n",
       "      <td>student</td>\n",
       "      <td>48823</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25426</th>\n",
       "      <td>3.875000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.250000</td>\n",
       "      <td>0.125000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.375000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.250000</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>3.666667</td>\n",
       "      <td>4.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>21</td>\n",
       "      <td>M</td>\n",
       "      <td>student</td>\n",
       "      <td>48823</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25427</th>\n",
       "      <td>3.888889</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.222222</td>\n",
       "      <td>0.111111</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>4.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>3.666667</td>\n",
       "      <td>4.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>21</td>\n",
       "      <td>M</td>\n",
       "      <td>student</td>\n",
       "      <td>48823</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 43 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       mean_rate  unknown    Action  Adventure  Animation  Children's  \\\n",
       "25423   4.000000      0.0  0.400000   0.200000        0.0         0.0   \n",
       "25425   4.000000      0.0  0.285714   0.142857        0.0         0.0   \n",
       "25424   4.000000      0.0  0.333333   0.166667        0.0         0.0   \n",
       "25426   3.875000      0.0  0.250000   0.125000        0.0         0.0   \n",
       "25427   3.888889      0.0  0.222222   0.111111        0.0         0.0   \n",
       "\n",
       "         Comedy  Crime  Documentary     Drama  ...  Mystery_rate  \\\n",
       "25423  0.400000    0.0          0.0  0.200000  ...           NaN   \n",
       "25425  0.428571    0.0          0.0  0.285714  ...           NaN   \n",
       "25424  0.333333    0.0          0.0  0.333333  ...           NaN   \n",
       "25426  0.375000    0.0          0.0  0.250000  ...           NaN   \n",
       "25427  0.333333    0.0          0.0  0.333333  ...           NaN   \n",
       "\n",
       "       Romance_rate  Sci-Fi_rate  Thriller_rate  War_rate  Western_rate  age  \\\n",
       "25423           4.0          4.0       4.000000       4.0           NaN   21   \n",
       "25425           4.0          4.0       4.000000       4.0           NaN   21   \n",
       "25424           4.0          4.0       4.000000       4.0           NaN   21   \n",
       "25426           4.0          4.0       3.666667       4.0           NaN   21   \n",
       "25427           4.0          4.0       3.666667       4.0           NaN   21   \n",
       "\n",
       "       gender  occupation  zip_code  \n",
       "25423       M     student     48823  \n",
       "25425       M     student     48823  \n",
       "25424       M     student     48823  \n",
       "25426       M     student     48823  \n",
       "25427       M     student     48823  \n",
       "\n",
       "[5 rows x 43 columns]"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "840e59a2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([772, 288, 108, ..., 183, 432, 509])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "516d2fd5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<76228x1683 sparse matrix of type '<class 'numpy.float64'>'\n",
       "\twith 7957390 stored elements in COOrdinate format>"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_movies_watched"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "a4cba74d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['173', '185', '255', '286', '298']"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sorted(df_train.prev_movies.tolist()[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "a4f11af4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([0, 0, 0, 0, 0]), array([173, 185, 255, 286, 298]))"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.where(train_movies_watched.todense()[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d7dd7bc",
   "metadata": {},
   "source": [
    "From here on, the specifics related to this library start to appear. The only component that is going to be a bit different is the so-called tabular component, referred to as `continuous` in the notebook.\n",
    "\n",
    "In `pytorch-widedeep` we have the `TabPreprocessor`, which allows for a lot of flexibility in how we process the tabular component of this Wide and Deep model. In other words, our tabular component here is a bit more elaborate than the one in the notebook, just a bit...\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "733ea2a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "cat_cols = [\"gender\", \"occupation\", \"zip_code\"]\n",
    "cont_cols = [c for c in X_train if c not in cat_cols]\n",
    "tab_preprocessor = TabPreprocessor(\n",
    "    cat_embed_cols=cat_cols,\n",
    "    continuous_cols=cont_cols,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "68555183",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train_tab = tab_preprocessor.fit_transform(X_train.fillna(0))\n",
    "X_test_tab = tab_preprocessor.transform(X_test.fillna(0))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a00da28c",
   "metadata": {},
   "source": [
    "Now, in the notebook, the author converts the sparse matrices to sparse tensors and then turns them into dense tensors. In reality, this is not necessary: one could feed sparse tensors to `nn.Linear` layers in PyTorch. Nonetheless, this is not the most efficient implementation, which is why in our library the wide, linear component is implemented as an embedding layer.\n",
    "\n",
    "Still, to reproduce the notebook as closely as we can, and because the `Wide` model in `pytorch-widedeep` is currently not designed to receive sparse tensors (we might consider implementing this functionality), we will turn the sparse COO matrices into dense arrays. We will then code a fairly simple, custom `Wide` component."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "20903dd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train_wide = np.array(train_movies_watched.todense())\n",
    "X_test_wide = np.array(test_movies_watched.todense())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "377e7f90",
   "metadata": {},
   "source": [
    "Finally, the author of the notebook uses a simple `Embedding` layer to encode the sequences of movies watched, i.e. the `prev_movies` column. In my opinion, there is an element of information redundancy here, because the wide and text components implicitly carry the same information in different forms. Moreover, both of the models used for these two components ignore the sequential element in the data. Nonetheless, we want to reproduce the Kaggle notebook as closely as possible, and, as one can explore later (by performing simple ablation studies), the wide component seems to carry most of the predictive power."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "c52fd52c",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train_text = sparse_to_idx(train_movies_watched, pad_idx=PAD_IDX)\n",
    "X_test_text = sparse_to_idx(test_movies_watched, pad_idx=PAD_IDX)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ca8b84d",
   "metadata": {},
   "source": [
    "Let's now build the models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "44bc73d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Wide(nn.Module):\n",
    "    def __init__(self, input_dim: int, pred_dim: int):\n",
    "        super().__init__()\n",
    "\n",
    "        self.input_dim = input_dim\n",
    "        self.pred_dim = pred_dim\n",
    "\n",
    "        # When I coded the library I never thought that someone would want to code\n",
    "        # their own wide component. However, if you do, the wide component must have\n",
    "        # a 'wide_linear' attribute. In other words, the linear layer must be\n",
    "        # called 'wide_linear'\n",
    "        self.wide_linear = nn.Linear(input_dim, pred_dim)\n",
    "\n",
    "    def forward(self, X):\n",
    "        out = self.wide_linear(X.type(torch.float32))\n",
    "        return out\n",
    "\n",
    "\n",
    "wide = Wide(X_train_wide.shape[1], max_movie_index + 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "6f66130d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Wide(\n",
       "  (wide_linear): Linear(in_features=1683, out_features=1683, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "wide"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "25592d30",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleEmbed(nn.Module):\n",
    "    def __init__(self, vocab_size: int, embed_dim: int, pad_idx: int):\n",
    "        super().__init__()\n",
    "\n",
    "        self.vocab_size = vocab_size\n",
    "        self.embed_dim = embed_dim\n",
    "        self.pad_idx = pad_idx\n",
    "\n",
    "        # The sequences of movies watched are simply embedded in the Kaggle\n",
    "        # notebook. No RNN, Transformer or any model is used\n",
    "        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)\n",
    "\n",
    "    def forward(self, X):\n",
    "        embed = self.embed(X)\n",
    "        embed_mean = torch.mean(embed, dim=1)\n",
    "        return embed_mean\n",
    "\n",
    "    @property\n",
    "    def output_dim(self) -> int:\n",
    "        # All deep components in a custom 'pytorch-widedeep' model must have\n",
    "        # an output_dim property\n",
    "        return self.embed_dim\n",
    "\n",
    "\n",
    "#  In the notebook the author uses simply embeddings\n",
    "simple_embed = SimpleEmbed(max_movie_index + 1, 16, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "492f12c5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SimpleEmbed(\n",
       "  (embed): Embedding(1683, 16, padding_idx=0)\n",
       ")"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "simple_embed"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe9f137a",
   "metadata": {},
   "source": [
    "Maybe one would like to use an RNN to account for the sequence nature of the problem. If that was the case it would be as easy as: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "0c3f17b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "basic_rnn = BasicRNN(\n",
    "    vocab_size=max_movie_index + 1,\n",
    "    embed_dim=16,\n",
    "    hidden_dim=32,\n",
    "    n_layers=2,\n",
    "    rnn_type=\"gru\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e410d5d9",
   "metadata": {},
   "source": [
    "And finally, the tabular component, which is the notebook is simply a stak of linear + Rely layers. In our case we have an embedding layer before the linear layers to encode categorial and numerical cols"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "ca721555",
   "metadata": {},
   "outputs": [],
   "source": [
    "tab_mlp = TabMlp(\n",
    "    column_idx=tab_preprocessor.column_idx,\n",
    "    cat_embed_input=tab_preprocessor.cat_embed_input,\n",
    "    continuous_cols=tab_preprocessor.continuous_cols,\n",
    "    cont_norm_layer=None,\n",
    "    mlp_hidden_dims=[1024, 512, 256],\n",
    "    mlp_activation=\"relu\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "25c25e3a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TabMlp(\n",
       "  (cat_embed): DiffSizeCatEmbeddings(\n",
       "    (embed_layers): ModuleDict(\n",
       "      (emb_layer_gender): Embedding(3, 2, padding_idx=0)\n",
       "      (emb_layer_occupation): Embedding(22, 9, padding_idx=0)\n",
       "      (emb_layer_zip_code): Embedding(648, 60, padding_idx=0)\n",
       "    )\n",
       "    (embedding_dropout): Dropout(p=0.0, inplace=False)\n",
       "  )\n",
       "  (cont_norm): Identity()\n",
       "  (encoder): MLP(\n",
       "    (mlp): Sequential(\n",
       "      (dense_layer_0): Sequential(\n",
       "        (0): Linear(in_features=111, out_features=1024, bias=True)\n",
       "        (1): ReLU(inplace=True)\n",
       "        (2): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "      (dense_layer_1): Sequential(\n",
       "        (0): Linear(in_features=1024, out_features=512, bias=True)\n",
       "        (1): ReLU(inplace=True)\n",
       "        (2): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "      (dense_layer_2): Sequential(\n",
       "        (0): Linear(in_features=512, out_features=256, bias=True)\n",
       "        (1): ReLU(inplace=True)\n",
       "        (2): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tab_mlp"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b68c5bc9",
   "metadata": {},
   "source": [
    "Finally, we simply wrap up all models with the `WideDeep` 'collector' class and we are ready to train. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "4c6acc08",
   "metadata": {},
   "outputs": [],
   "source": [
    "wide_deep_model = WideDeep(\n",
    "    wide=wide, deeptabular=tab_mlp, deeptext=simple_embed, pred_dim=max_movie_index + 1\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "bc8970f7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "WideDeep(\n",
       "  (wide): Wide(\n",
       "    (wide_linear): Linear(in_features=1683, out_features=1683, bias=True)\n",
       "  )\n",
       "  (deeptabular): Sequential(\n",
       "    (0): TabMlp(\n",
       "      (cat_embed): DiffSizeCatEmbeddings(\n",
       "        (embed_layers): ModuleDict(\n",
       "          (emb_layer_gender): Embedding(3, 2, padding_idx=0)\n",
       "          (emb_layer_occupation): Embedding(22, 9, padding_idx=0)\n",
       "          (emb_layer_zip_code): Embedding(648, 60, padding_idx=0)\n",
       "        )\n",
       "        (embedding_dropout): Dropout(p=0.0, inplace=False)\n",
       "      )\n",
       "      (cont_norm): Identity()\n",
       "      (encoder): MLP(\n",
       "        (mlp): Sequential(\n",
       "          (dense_layer_0): Sequential(\n",
       "            (0): Linear(in_features=111, out_features=1024, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (dense_layer_1): Sequential(\n",
       "            (0): Linear(in_features=1024, out_features=512, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (dense_layer_2): Sequential(\n",
       "            (0): Linear(in_features=512, out_features=256, bias=True)\n",
       "            (1): ReLU(inplace=True)\n",
       "            (2): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (1): Linear(in_features=256, out_features=1683, bias=True)\n",
       "  )\n",
       "  (deeptext): Sequential(\n",
       "    (0): SimpleEmbed(\n",
       "      (embed): Embedding(1683, 16, padding_idx=0)\n",
       "    )\n",
       "    (1): Linear(in_features=16, out_features=1683, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "wide_deep_model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e08d41ed",
   "metadata": {},
   "source": [
    "Note that the main difference between this wide and deep model and the Wide and Deep model in the Kaggle notebook is that in that notebook, the author concatenates the embedings and the tabular features, then passes this concatenation through a stack of linear + Relu layers with a final output dim of 256. Then concatenates this output with the binary features and connects this concatenation with the final linear layer (so the final weights are of dim (batch_size, 256 + 1683)). Our implementation follows the notation of the original paper and instead of concatenating the tabular, text and wide components and then connect them to the output neurons, we first compute their output, and then add it (see here: https://arxiv.org/pdf/1606.07792.pdf, their Eq 3). Note that this is effectively the same, with the caveat that while in one case one initialises a big weight matrix \"at once\", in our implementation we initialise different matrices for different components. Anyway, let's give it a go."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "538a34de",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model=wide_deep_model,\n",
    "    objective=\"multiclass\",\n",
    "    custom_loss_function=nn.CrossEntropyLoss(ignore_index=PAD_IDX),\n",
    "    optimizers=torch.optim.Adam(wide_deep_model.parameters(), lr=1e-3),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "77c02ed5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:19<00:00,  7.66it/s, loss=6.66]\n",
      "valid: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:02<00:00, 18.75it/s, loss=6.6]\n",
      "epoch 2: 100%|████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:21<00:00,  6.95it/s, loss=5.97]\n",
      "valid: 100%|████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 21.03it/s, loss=6.52]\n",
      "epoch 3: 100%|████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:19<00:00,  7.51it/s, loss=5.65]\n",
      "valid: 100%|████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 20.16it/s, loss=6.53]\n",
      "epoch 4: 100%|████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:23<00:00,  6.29it/s, loss=5.41]\n",
      "valid: 100%|████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:02<00:00, 13.97it/s, loss=6.57]\n",
      "epoch 5: 100%|█████████████████████████████████████████████████████████████████████████████████████| 149/149 [00:19<00:00,  7.58it/s, loss=5.2]\n",
      "valid: 100%|████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:02<00:00, 18.82it/s, loss=6.63]\n"
     ]
    }
   ],
   "source": [
    "trainer.fit(\n",
    "    X_train={\n",
    "        \"X_wide\": X_train_wide,\n",
    "        \"X_tab\": X_train_tab,\n",
    "        \"X_text\": X_train_text,\n",
    "        \"target\": y_train,\n",
    "    },\n",
    "    X_val={\n",
    "        \"X_wide\": X_test_wide,\n",
    "        \"X_tab\": X_test_tab,\n",
    "        \"X_text\": X_test_text,\n",
    "        \"target\": y_test,\n",
    "    },\n",
    "    n_epochs=5,\n",
    "    batch_size=512,\n",
    "    shuffle=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8f9aec7",
   "metadata": {},
   "source": [
    "Now one could continue to the 'compare' metrics section of the Kaggle notebook. However, for the purposes of illustrating how one could use `pytorch-widedeep` to build recommendation algorithms we consider this notebook completed and move onto part 2 "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
